{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02153.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02153.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TS8BpDlEBfK2",
        "colab_type": "text"
      },
      "source": [
        "## 6.4 循环神经网络的从零开始实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPquFyCtCNIJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "import math \n",
        "import numpy as np \n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "import time \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dcXGC6x_FAGV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!mkdir ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S4C6PVOjFdm5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "6e694634-96a9-40b6-8feb-66f8b1590ecc"
      },
      "source": [
        "!git clone https://github.com/ShusenTang/Dive-into-DL-PyTorch.git"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'Dive-into-DL-PyTorch'...\n",
            "remote: Enumerating objects: 1692, done.\u001b[K\n",
            "remote: Total 1692 (delta 0), reused 0 (delta 0), pack-reused 1692\u001b[K\n",
            "Receiving objects: 100% (1692/1692), 25.29 MiB | 16.89 MiB/s, done.\n",
            "Resolving deltas: 100% (975/975), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CnGzwIzXFybH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!cp Dive-into-DL-PyTorch/data/jaychou_lyrics.txt.zip ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9ey3ETiwFqAE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BpjQitoNF7Td",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.1 one-hot 向量"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9nWdgKhdGESD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "eea7e0f4-ba01-4bb8-d105-b5086c130cd8"
      },
      "source": [
        "def one_hot(x, n_class, dtype=torch.float32):\n",
        "    x = x.long()\n",
        "    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)\n",
        "    res.scatter_(1, x.view(-1, 1), 1)\n",
        "    return res \n",
        "\n",
        "x = torch.tensor([0, 2])\n",
        "one_hot(x, vocab_size)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[1., 0., 0.,  ..., 0., 0., 0.],\n",
              "        [0., 0., 1.,  ..., 0., 0., 0.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NH77RlaVGhJ5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "1ea8cc37-4297-4a82-b5b1-182572f19a18"
      },
      "source": [
        "def to_onehot(X, n_class):\n",
        "    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]\n",
        "\n",
        "X = torch.arange(10).view(2, 5)\n",
        "inputs = to_onehot(X, vocab_size)\n",
        "len(inputs), inputs[0].shape"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(5, torch.Size([2, 1027]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n454jJieG7KN",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.2 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s3Mm-bFAHBd6",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "92586dbb-17ef-4c09-e302-9c5d5fd576f5"
      },
      "source": [
        "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size \n",
        "print('will use', device)\n",
        "\n",
        "def get_params():\n",
        "    def _one(shape):\n",
        "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
        "        return torch.nn.Parameter(ts, requires_grad=True)\n",
        "\n",
        "    W_xh = _one((num_inputs, num_hiddens))\n",
        "    W_hh = _one((num_hiddens, num_hiddens))\n",
        "    b_h = torch.nn.Parameter(torch.zeros(num_hiddens, device=device, requires_grad=True))\n",
        "    W_hq = _one((num_hiddens, num_outputs))\n",
        "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, requires_grad=True))\n",
        "    return nn.ParameterList([W_xh, W_hh, b_h, W_hq, b_q])"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "will use cuda\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3g6Cbm6iIKrZ",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.3 定义模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r85Yj07TIPQ1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def init_rnn_state(batch_size, num_hiddens, device):\n",
        "    return (torch.zeros((batch_size, num_hiddens), device=device), )"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MQIjLWtwIdpv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def rnn(inputs, state, params):\n",
        "    W_xh, W_hh, b_h, W_hq, b_q = params \n",
        "    H, = state \n",
        "    outputs = []\n",
        "    for X in inputs:\n",
        "        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)\n",
        "        Y = torch.matmul(H, W_hq) + b_q\n",
        "        outputs.append(Y)\n",
        "    return outputs, (H, )"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m73mJc3_JI_L",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "06e0a538-a9c7-41d3-de02-64518a846bd2"
      },
      "source": [
        "state = init_rnn_state(X.shape[0], num_hiddens, device)\n",
        "inputs = to_onehot(X.to(device), vocab_size)\n",
        "params = get_params()\n",
        "outputs, state_new = rnn(inputs, state, params)\n",
        "len(outputs), outputs[0].shape, state_new[0].shape"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(5, torch.Size([2, 1027]), torch.Size([2, 256]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovJkVLa-Jkx_",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.4 定义预测函数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZjwMDittKa-F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state, \n",
        "        num_hiddens, vocab_size, device, idx_to_char, char_to_idx):\n",
        "    state = init_rnn_state(1, num_hiddens, device) \n",
        "    output = [char_to_idx[prefix[0]]]\n",
        "    for t in range(num_chars + len(prefix) - 1):\n",
        "        X = to_onehot(torch.tensor([[output[-1]]], device=device), vocab_size)\n",
        "        (Y, state) = rnn(X, state, params)\n",
        "        if t < len(prefix) - 1:\n",
        "            output.append(char_to_idx[prefix[t + 1]])\n",
        "        else:\n",
        "            output.append(int(Y[0].argmax(dim=1).item()))\n",
        "    return ''.join([idx_to_char[i] for i in output])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sBE5szROLonf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "5b8db49b-5fab-4838-8d00-bddd61d2cb3f"
      },
      "source": [
        "predict_rnn('分开', 10, rnn, params, init_rnn_state, num_hiddens, vocab_size, device, idx_to_char, char_to_idx)"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'分开吴苍狠鼠苍狠鼠苍狠鼠'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Up0tigUiL4UO",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.5 裁剪梯度"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JQNberbkMD5S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def grad_clipping(params, theta, device):\n",
        "    norm = torch.tensor([0.0], device=device)\n",
        "    for param in params:\n",
        "        norm += (param.grad.data ** 2).sum()\n",
        "    norm = norm.sqrt().item()\n",
        "    if norm > theta:\n",
        "        for param in params:\n",
        "            param.grad.data *= (theta / norm)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K92QLFzTMfNa",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.7 定义模型训练函数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GmeHbH9LMrPi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, \n",
        "            char_to_idx, is_random_iter, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_peroid, \n",
        "            pred_len, prefixes):\n",
        "    if is_random_iter:\n",
        "        data_iter_fn = d2l.data_iter_random \n",
        "    else:\n",
        "        data_iter_fn = d2l.data_iter_consecutive \n",
        "    params = get_params()\n",
        "    loss = nn.CrossEntropyLoss()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        if not is_random_iter:\n",
        "            state = init_rnn_state(batch_size, num_hiddens, device)\n",
        "        l_sum, n, start = 0.0, 0, time.time()\n",
        "        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)\n",
        "        for X, Y in data_iter:\n",
        "            if is_random_iter:\n",
        "                state = init_rnn_state(batch_size, num_hiddens, device)\n",
        "            else:\n",
        "                for s in state:\n",
        "                    s.detach_()\n",
        "\n",
        "            inputs = to_onehot(X, vocab_size)\n",
        "            (outputs, state) = rnn(inputs, state, params)\n",
        "            outputs = torch.cat(outputs, dim=0)\n",
        "            y = torch.transpose(Y, 0, 1).contiguous().view(-1)\n",
        "            l = loss(outputs, y.long())\n",
        "\n",
        "            if params[0].grad is not None:\n",
        "                for param in params:\n",
        "                    param.grad.data.zero_()\n",
        "            l.backward()\n",
        "            grad_clipping(params, clipping_theta, device)\n",
        "            d2l.sgd(params, lr, 1)\n",
        "            l_sum += l.item() * y.shape[0]\n",
        "            n += y.shape[0]\n",
        "\n",
        "        if (epoch + 1) % pred_period == 0:\n",
        "            print('epoch %d, perplexity %f, time %.2f sec' \n",
        "                % (epoch + 1, math.exp(l_sum / n), time.time() - start))\n",
        "            for prefix in prefixes:\n",
        "                print(' -', predict_rnn(prefix, pred_len, rnn, params, init_rnn_state, \n",
        "                   num_hiddens, vocab_size, device, idx_to_char, char_to_idx))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y5tAk0TbSn4E",
        "colab_type": "text"
      },
      "source": [
        "### 6.4.8 训练模型并创作歌词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z578QNyIVNDM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_epochs, num_steps, batch_size, lr, clipping_theta = 250, 35, 32, 1e2, 1e-2 \n",
        "pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DU3p4UmiVhWv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "bfc65917-b314-45c7-8c74-07b822abc706"
      },
      "source": [
        "train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, \n",
        "        char_to_idx, True, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 50, perplexity 68.871115, time 0.09 sec\n",
            " - 分开 我想要这 你使了 别怪我 娘子就 一直么 我想要的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯\n",
            " - 不分开 我想要再 你小了 别怪我 泪不 让我的红  谁的让我 狂的可爱女人 一坏的让我疯狂的可爱女人 坏坏\n",
            "epoch 100, perplexity 10.213086, time 0.09 sec\n",
            " - 分开 一直两它一步四步 连成线背著背 我怀在我生朋友样外加找 你十的让我面红 可小就耳濡路染 我想要你的\n",
            " - 不分开吗 我不要和你离着我的手不 铺原的外板人  爱穿了的话剩 像 在我胸口睡久 像这样的生活 我的你 你\n",
            "epoch 150, perplexity 2.783084, time 0.09 sec\n",
            " - 分开 一步两它开 谁让它停留的 为什么我女朋友场外加油 你却还让我出糗 从小就耳濡路染 什么刀枪了棍棒 \n",
            " - 不分开扫 然后将过了 慢慢温在 全暖了空屋 白色蜡烛 温暖了空屋 白色蜡烛 温暖了空屋 白色蜡烛 温暖了空\n",
            "epoch 200, perplexity 1.637130, time 0.09 sec\n",
            " - 分开 一步在它因不达米亚平  养的我猫坦口让记 不有再这 你想我有多的菸 一小恩怨 在上耿中  一场梦 \n",
            " - 不分开期 然后将过去 慢慢温习 全家了空出 白色蜡烛 温暖了空屋 白色蜡烛 温暖了空屋 白色蜡烛 温暖了空\n",
            "epoch 250, perplexity 1.302495, time 0.09 sec\n",
            " - 分开 一直令它留 谁让它停留的 为什么我女朋友场外加油 你却还让我出糗 可小就迷了路怎么找也找不着 心血\n",
            " - 不分开扫 我叫你爸 你打我妈 这样对吗干嘛这样 何必让酒牵鼻子走 瞎 说说你打单车一个 想想你说样来我单妈\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "My1bkmf7V7_0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "9d004f13-dd30-444c-fb89-95bcc9ab6308"
      },
      "source": [
        "train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, \n",
        "        char_to_idx, False, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)"
      ],
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 50, perplexity 60.996227, time 0.09 sec\n",
            " - 分开 我想要再 你有 我想要再不 我不 再不 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想\n",
            " - 不分开 我想要你的溪 我不要这 我有了空 你不了 别怪我 一颗我 一颗我 一颗我 一颗我 一颗我 一颗我 \n",
            "epoch 100, perplexity 7.106263, time 0.09 sec\n",
            " - 分开 一场的假旧在 帅呆了我 全场了人 习果用直在过 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 哼\n",
            " - 不分开觉 你已经抽 我谁了有节 我一定空 装我了外的溪边 我默店的手色在西元前 深埋在美索不达米亚平  姥\n",
            "epoch 150, perplexity 2.100979, time 0.10 sec\n",
            " - 分开 一养堂 你在我 说你球 干沉空 停子病一信代 唱有人看着 是让它人鸦的你爱下 在够翻文引刻 所有在\n",
            " - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后兮 快使用双截棍 哼哼哈兮 \n",
            "epoch 200, perplexity 1.299248, time 0.09 sec\n",
            " - 分开 是你已 你打堂枪 回止看 在一句珍停留 一直在停留 谁让它 一句惹毛三会 不着人停留 谁让它 一烧\n",
            " - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生\n",
            "epoch 250, perplexity 1.160078, time 0.09 sec\n",
            " - 分开 问候就这样打是开 回到儿回爱打我妈妈 难道你手不会痛吗 其实我回家就想要阻止一切 让家庭回到过去甜\n",
            " - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}